# -*- coding: utf-8 -*-
"""
@Time    : 2024/8/27 15:47 
@Author  : ZhangShenao 
@File    : dalle_image_tool.py 
@Desc    : OpenAI DALLE文生图工具
"""
from langchain_community.tools.openai_dalle_image_generation import OpenAIDALLEImageGenerationTool
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
from langchain_core.pydantic_v1 import BaseModel, Field


class DallEArgsSchema(BaseModel):
    """生成图片的参数定义"""
    query: str = Field(description="生成图像的文本提示(prompt)")


def create_dalle_image_tool() -> OpenAIDALLEImageGenerationTool:
    """创建DALLE文生图工具"""

    return OpenAIDALLEImageGenerationTool(
        name="openai_dalle",  # 指定工具名称
        description="OpenAI DALLE 文生图工具",  # 指定工具描述
        api_wrapper=DallEAPIWrapper(model="dall-e-3"),  # 指定API包装器
        args_schema=DallEArgsSchema,  # 指定工具参数定义
    )
